package com.xiaofan.scala

import org.apache.flink.api.common.functions.RichFlatMapFunction
import org.apache.flink.api.common.state.{ValueState, ValueStateDescriptor}
import org.apache.flink.api.common.typeinfo.{TypeHint, TypeInformation}
import org.apache.flink.configuration.Configuration
import org.apache.flink.streaming.api.scala.{StreamExecutionEnvironment, _}
import org.apache.flink.util.Collector

object CountWindowAverage_D0001 {

  /**
   * Demo entry point: feeds a small bounded stream of `(key, value)` pairs, all
   * sharing key `1L`, through [[CountWindowAverageFunction]] and prints whatever it emits.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    val environment = StreamExecutionEnvironment.getExecutionEnvironment

    val readings: DataStream[(Long, Long)] =
      environment.fromElements((1L, 3L), (1L, 5L), (1L, 7L), (1L, 4L), (1L, 2L))

    // Keyed state in the flat-map requires a keyed stream; partition on the first field.
    readings
      .keyBy(reading => reading._1)
      .flatMap(new CountWindowAverageFunction)
      .print()

    environment.execute("CountWindowAverage_D0001")
  }
}

/**
 * Keyed flat-map that averages the second field of every `windowSize` consecutive
 * elements seen for a key, keeping a running `(count, sum)` in keyed `ValueState`.
 *
 * For each input `(key, value)` the accumulator is advanced. Once the updated count
 * reaches `windowSize`, `(key, sum / count)` is emitted (integer division) and the
 * key's state is cleared so the next window starts from zero. Elements that do not
 * complete a window before the stream ends are never emitted.
 *
 * Must be applied to a keyed stream (after `keyBy`), because it uses keyed state.
 *
 * @param windowSize number of elements per key that form one window; must be
 *                   positive. Defaults to 2.
 */
class CountWindowAverageFunction(windowSize: Long = 2L)
  extends RichFlatMapFunction[(Long, Long), (Long, Long)] {

  require(windowSize > 0, s"windowSize must be positive, got $windowSize")

  // Per-key running (count, sum). Assigned in `open`; marked @transient because the
  // function object itself is serialized when the job is submitted, and the state
  // handle is only meaningful once obtained from the runtime context.
  @transient private var sum: ValueState[(Long, Long)] = _

  override def open(parameters: Configuration): Unit = {
    // `createTypeInformation` (from the Scala API wildcard import) builds tuple type
    // info with Long field serializers instead of a generic/Kryo fallback.
    val descriptor: ValueStateDescriptor[(Long, Long)] =
      new ValueStateDescriptor[(Long, Long)]("average", createTypeInformation[(Long, Long)])
    sum = getRuntimeContext.getState(descriptor)
  }

  override def flatMap(input: (Long, Long), out: Collector[(Long, Long)]): Unit = {
    // ValueState.value() yields null when nothing has been stored for this key yet.
    val (count, total) = Option(sum.value()).getOrElse((0L, 0L))
    val newCount = count + 1
    val newTotal = total + input._2

    // Decide on the *updated* accumulator so the window fires on exactly the
    // `windowSize`-th element and that element's value is part of the average.
    if (newCount >= windowSize) {
      out.collect((input._1, newTotal / newCount))
      sum.clear()
    } else {
      sum.update((newCount, newTotal))
    }
  }

}
